{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Textual inversion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [StableDiffusionPipeline](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline) supports textual inversion, a technique that enables a model like Stable Diffusion to learn a new concept from just a few sample images. This gives you more control over the generated images and allows you to tailor the model towards specific concepts. You can get started quickly with a collection of community created concepts in the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer).\n",
    "\n",
    "This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](https://huggingface.co/docs/diffusers/main/en/using-diffusers/./training/text_inversion) training guide.\n",
    "\n",
    "Login to your Hugging Face account:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the necessary libraries, and create a helper function to visualize the generated images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "import PIL\n",
    "from PIL import Image\n",
    "\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
    "\n",
    "\n",
    "def image_grid(imgs, rows, cols):\n",
    "    assert len(imgs) == rows * cols\n",
    "\n",
    "    w, h = imgs[0].size\n",
    "    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n",
    "    grid_w, grid_h = grid.size\n",
    "\n",
    "    for i, img in enumerate(imgs):\n",
    "        grid.paste(img, box=(i % cols * w, i // cols * h))\n",
    "    return grid"
   ]
  },
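  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check before running the model, the helper can be tried on plain solid-color tiles. This is a minimal sketch that assumes the cell above has been run; the tile size and colors are arbitrary choices made for illustration and are not part of the original guide:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "# Four placeholder tiles (arbitrary size and colors, for illustration only).\n",
    "tiles = [Image.new(\"RGB\", (64, 64), color) for color in (\"red\", \"green\", \"blue\", \"white\")]\n",
    "\n",
    "# Two rows of two tiles each; tiles are placed left to right, then top to bottom.\n",
    "demo_grid = image_grid(tiles, rows=2, cols=2)\n",
    "print(demo_grid.size)  # (128, 128): 2 columns x 64 px wide, 2 rows x 64 px tall\n",
    "demo_grid"
   ]
  },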
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrained_model_name_or_path = \"runwayml/stable-diffusion-v1-5\"\n",
    "repo_id_embeds = \"sd-concepts-library/cat-toy\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you can load a pipeline, and pass the pre-learned concept to it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16).to(\"cuda\")\n",
    "\n",
    "pipeline.load_textual_inversion(repo_id_embeds)"
   ]
  },
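  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm that the concept was registered, you can look up its placeholder token in the pipeline's tokenizer. This is a minimal sketch that assumes the loading cell above has been run and that the concept uses the placeholder token `<cat-toy>`; the variable name `placeholder_token` is introduced here only for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "placeholder_token = \"<cat-toy>\"  # assumed placeholder token for this concept\n",
    "\n",
    "# After loading, the tokenizer vocabulary is expected to contain the new token.\n",
    "print(placeholder_token in pipeline.tokenizer.get_vocab())\n",
    "\n",
    "# The token id identifies the learned embedding for the concept in the text encoder.\n",
    "print(pipeline.tokenizer.convert_tokens_to_ids(placeholder_token))"
   ]
  },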
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a prompt with the pre-learned concept by using the special placeholder token `<cat-toy>`, and choose the number of samples and rows of images you'd like to generate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"a grafitti in a favela wall with a <cat-toy> on it\"\n",
    "\n",
    "num_samples = 2\n",
    "num_rows = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then run the pipeline (feel free to adjust the parameters like `num_inference_steps` and `guidance_scale` to see how they affect image quality), save the generated images and visualize them with the helper function you created at the beginning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_images = []\n",
    "for _ in range(num_rows):\n",
    "    images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images\n",
    "    all_images.extend(images)\n",
    "\n",
    "grid = image_grid(all_images, num_samples, num_rows)\n",
    "grid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/textual_inversion_inference.png\">\n",
    "</div>"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
